from typing import Dict, List

import numpy as np
from injector import inject
from sklearn.metrics.pairwise import cosine_similarity

from taskweaver.llm import LLMApi
from taskweaver.memory.plugin import PluginEntry, PluginRegistry


class SelectedPluginPool:
    """Pool of plugins selected for code generation, de-duplicated by plugin name.

    Whenever two entries share a ``name``, the one encountered first is kept.
    """

    def __init__(self):
        # Plugins currently offered for code generation.
        self.selected_plugin_pool: List[PluginEntry] = []
        # Plugins that appeared in previously generated code. filter_unused_plugins
        # merges into this list, so a plugin used once stays in the pool afterwards.
        self._previous_used_plugin_cache: List[PluginEntry] = []

    def add_selected_plugins(self, external_plugin_pool: List[PluginEntry]):
        """
        Add selected plugins to the pool.

        Plugins whose name is already in the pool are ignored.
        """
        self.selected_plugin_pool = self.merge_plugin_pool(self.selected_plugin_pool, external_plugin_pool)

    def __len__(self) -> int:
        return len(self.selected_plugin_pool)

    def filter_unused_plugins(self, code: str):
        """
        Filter out plugins that are not used in the code generated by LLM.

        A plugin counts as used when its name occurs anywhere in ``code`` as a
        plain substring (so a name embedded in a longer identifier also matches).
        Used plugins are merged into the cache of previously used plugins, and the
        pool is reset to that cache, so plugins used in earlier code are retained.
        """
        plugins_used_in_code = [p for p in self.selected_plugin_pool if p.name in code]
        self._previous_used_plugin_cache = self.merge_plugin_pool(
            self._previous_used_plugin_cache,
            plugins_used_in_code,
        )
        # Copy so the pool and the cache are distinct lists; otherwise mutating
        # the list returned by get_plugins() would silently alter the cache too.
        self.selected_plugin_pool = list(self._previous_used_plugin_cache)

    def get_plugins(self) -> List[PluginEntry]:
        """Return the current pool (the internal list, not a copy)."""
        return self.selected_plugin_pool

    @staticmethod
    def merge_plugin_pool(pool1: List[PluginEntry], pool2: List[PluginEntry]) -> List[PluginEntry]:
        """
        Merge two plugin pools and remove duplicates.

        Order is preserved: entries of ``pool1`` come first, then entries of
        ``pool2``. For duplicate names only the first occurrence is kept.
        Returns a new list; neither input is modified.
        """
        # A set of seen names gives O(1) duplicate checks instead of rescanning
        # the result list for every item.
        seen_names = set()
        result = []
        for pool in (pool1, pool2):
            for item in pool:
                if item.name not in seen_names:
                    seen_names.add(item.name)
                    result.append(item)
        return result


class PluginSelector:
    """Rank registered plugins by embedding similarity to a user query."""

    @inject
    def __init__(
        self,
        plugin_registry: PluginRegistry,
        llm_api: LLMApi,
    ):
        self.plugin_registry = plugin_registry
        self.llm_api = llm_api
        # Maps plugin name -> embedding of the text "<name>: <description>";
        # populated by generate_plugin_embeddings().
        self.plugin_embedding_dict: Dict[str, List[float]] = {}

    def generate_plugin_embeddings(self):
        """Embed every registered plugin's name and description.

        Results are stored in ``plugin_embedding_dict`` keyed by plugin name.
        This must run before :meth:`plugin_select` ranks plugins.

        Raises:
            ValueError: if the embedding API returns a different number of
                embeddings than texts it was given.
        """
        # Snapshot the registry once so the texts sent to the API and the names
        # used as keys are guaranteed to line up.
        plugins = self.plugin_registry.get_list()
        if not plugins:
            # Nothing to embed; skip the API call entirely.
            return
        plugin_intro_text_list: List[str] = [p.name + ": " + p.spec.description for p in plugins]
        plugin_embeddings = self.llm_api.get_embedding_list(plugin_intro_text_list)
        if len(plugin_embeddings) != len(plugins):
            raise ValueError(
                f"Expected {len(plugins)} plugin embeddings, got {len(plugin_embeddings)}",
            )
        for p, embedding in zip(plugins, plugin_embeddings):
            self.plugin_embedding_dict[p.name] = embedding

    def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]:
        """Return up to ``top_k`` plugins most similar to ``user_query``.

        If ``top_k`` covers the whole registry, the registry's list is returned
        as-is (unranked) and the embedding API is not called. Otherwise plugins
        are ranked by cosine similarity between the query embedding and the
        embeddings stored by :meth:`generate_plugin_embeddings`, highest first.
        Equal scores keep registry order.

        Raises:
            KeyError: if a registered plugin has no stored embedding
                (e.g. generate_plugin_embeddings was not run after it was added).
        """
        plugins = self.plugin_registry.get_list()
        # Decide before embedding the query: when nothing needs ranking, the
        # (potentially remote) embedding call would be wasted.
        if not plugins or top_k >= len(plugins):
            return plugins

        user_query_embedding = np.array(self.llm_api.get_embedding(user_query)).reshape(1, -1)
        plugin_embedding_matrix = np.array([self.plugin_embedding_dict[p.name] for p in plugins])
        # One vectorized call yields shape (1, n_plugins); take the single row.
        similarities = cosine_similarity(user_query_embedding, plugin_embedding_matrix)[0]

        # sorted() is stable even with reverse=True, so ties keep registry order.
        plugins_rank = sorted(
            zip(plugins, similarities),
            key=lambda pair: pair[1],
            reverse=True,
        )[:top_k]

        return [p for p, _ in plugins_rank]
